{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThe book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!pip install keras keras-hub --upgrade -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "import os\n",
    "from IPython.core.magic import register_cell_magic\n",
    "\n",
    "@register_cell_magic\n",
    "def backend(line, cell):\n",
    "    \"\"\"Run `cell` only when KERAS_BACKEND matches the backend named in `line`.\n",
    "\n",
    "    Usage: `%%backend <name>`. The last whitespace-separated token of `line`\n",
    "    is taken as the required backend and compared with the KERAS_BACKEND\n",
    "    environment variable. On a match the cell body is executed; otherwise a\n",
    "    hint is printed and the cell body is skipped.\n",
    "    \"\"\"\n",
    "    tokens = line.split()\n",
    "    if not tokens:\n",
    "        # `%%backend` without an argument used to fail with an IndexError.\n",
    "        print(\"Usage: %%backend <backend name>, for example: %%backend jax\")\n",
    "        return\n",
    "    current, required = os.environ.get(\"KERAS_BACKEND\", \"\"), tokens[-1]\n",
    "    if current == required:\n",
    "        get_ipython().run_cell(cell)\n",
    "    else:\n",
    "        print(\n",
    "            f\"This cell requires the {required} backend. To run it, change KERAS_BACKEND to \"\n",
    "            f\"\\\"{required}\\\" at the top of the notebook, restart the runtime, and rerun the notebook.\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Image generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Deep learning for image generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Sampling from latent spaces of images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Variational autoencoders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Implementing a VAE with Keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "from keras import layers\n",
    "\n",
    "# Dimensionality of the latent space the encoder maps images into.\n",
    "latent_dim = 2\n",
    "\n",
    "# Encoder: image -> (z_mean, z_log_var), each of size latent_dim.\n",
    "image_inputs = keras.Input(shape=(28, 28, 1))\n",
    "x = layers.Conv2D(\n",
    "    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(image_inputs)\n",
    "x = layers.Conv2D(\n",
    "    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(x)\n",
    "x = layers.Flatten()(x)\n",
    "x = layers.Dense(units=16, activation=\"relu\")(x)\n",
    "z_mean = layers.Dense(units=latent_dim, name=\"z_mean\")(x)\n",
    "z_log_var = layers.Dense(units=latent_dim, name=\"z_log_var\")(x)\n",
    "encoder = keras.Model(\n",
    "    inputs=image_inputs, outputs=[z_mean, z_log_var], name=\"encoder\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "encoder.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import ops\n",
    "\n",
    "class Sampler(keras.Layer):\n",
    "    \"\"\"Draws a latent point from the Gaussian given by (z_mean, z_log_var).\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        # Seed source used for every random draw made in call().\n",
    "        self.seed_generator = keras.random.SeedGenerator()\n",
    "        # No build step depends on the input shape, so flag the layer as built.\n",
    "        self.built = True\n",
    "\n",
    "    def call(self, z_mean, z_log_var):\n",
    "        mean_shape = ops.shape(z_mean)\n",
    "        noise = keras.random.normal(\n",
    "            (mean_shape[0], mean_shape[1]), seed=self.seed_generator\n",
    "        )\n",
    "        # z = mean + sigma * noise, where sigma = exp(log_var / 2).\n",
    "        sigma = ops.exp(0.5 * z_log_var)\n",
    "        return z_mean + sigma * noise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Decoder: latent point -> single-channel image with sigmoid outputs.\n",
    "latent_inputs = keras.Input(shape=(latent_dim,))\n",
    "x = layers.Dense(units=7 * 7 * 64, activation=\"relu\")(latent_inputs)\n",
    "x = layers.Reshape(target_shape=(7, 7, 64))(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(x)\n",
    "decoder_outputs = layers.Conv2D(\n",
    "    filters=1, kernel_size=3, padding=\"same\", activation=\"sigmoid\"\n",
    ")(x)\n",
    "decoder = keras.Model(\n",
    "    inputs=latent_inputs, outputs=decoder_outputs, name=\"decoder\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "decoder.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class VAE(keras.Model):\n",
    "    \"\"\"Variational autoencoder wiring an encoder and a decoder together.\n",
    "\n",
    "    `call` only runs the encoder; sampling, decoding and both loss terms are\n",
    "    computed in `compute_loss`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.sampler = Sampler()\n",
    "        # Running means of the two loss terms.\n",
    "        self.reconstruction_loss_tracker = keras.metrics.Mean(\n",
    "            name=\"reconstruction_loss\"\n",
    "        )\n",
    "        self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Return the encoder output, i.e. the pair (z_mean, z_log_var).\"\"\"\n",
    "        return self.encoder(inputs)\n",
    "\n",
    "    def compute_loss(self, x, y, y_pred, sample_weight=None, training=True):\n",
    "        \"\"\"Return reconstruction loss plus the mean KL term for one batch.\n",
    "\n",
    "        `x` is the batch of input images and `y_pred` the (z_mean, z_log_var)\n",
    "        pair returned by `call`. `y`, `sample_weight` and `training` are part\n",
    "        of the method signature but are not used here.\n",
    "        \"\"\"\n",
    "        z_mean, z_log_var = y_pred\n",
    "        reconstruction = self.decoder(self.sampler(z_mean, z_log_var))\n",
    "\n",
    "        # Sum the per-pixel cross-entropy over axes 1 and 2, then average\n",
    "        # over the batch.\n",
    "        reconstruction_loss = ops.mean(\n",
    "            ops.sum(\n",
    "                keras.losses.binary_crossentropy(x, reconstruction), axis=(1, 2)\n",
    "            )\n",
    "        )\n",
    "        # Element-wise KL term; it is averaged when added to the total.\n",
    "        kl_loss = -0.5 * (\n",
    "            1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var)\n",
    "        )\n",
    "        total_loss = reconstruction_loss + ops.mean(kl_loss)\n",
    "\n",
    "        self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
    "        self.kl_loss_tracker.update_state(kl_loss)\n",
    "        return total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Labels are discarded; train and test images are pooled into one array.\n",
    "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
    "mnist_digits = np.concatenate((x_train, x_test))\n",
    "# Append a channel axis, convert to float32 and divide by 255.\n",
    "mnist_digits = mnist_digits[..., np.newaxis].astype(\"float32\") / 255\n",
    "\n",
    "vae = VAE(encoder, decoder)\n",
    "vae.compile(optimizer=keras.optimizers.Adam())\n",
    "vae.fit(x=mnist_digits, batch_size=128, epochs=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "n = 30\n",
    "digit_size = 28\n",
    "\n",
    "grid_x = np.linspace(-1, 1, n)\n",
    "grid_y = np.linspace(-1, 1, n)[::-1]\n",
    "\n",
    "# Decode the whole n x n grid of latent points with one predict() call\n",
    "# instead of n * n one-sample calls, each of which restarts the predict loop\n",
    "# and prints its own progress bar. Rows follow grid_y, columns follow grid_x.\n",
    "z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])\n",
    "decoded = vae.decoder.predict(z_grid, verbose=0)\n",
    "\n",
    "# (n * n, h, w, 1) -> (n, n, h, w) -> (n, h, n, w) -> (n * h, n * w) mosaic,\n",
    "# so tile (i, j) holds the digit decoded from (grid_x[j], grid_y[i]).\n",
    "figure = (\n",
    "    decoded.reshape(n, n, digit_size, digit_size)\n",
    "    .transpose(0, 2, 1, 3)\n",
    "    .reshape(n * digit_size, n * digit_size)\n",
    ")\n",
    "\n",
    "plt.figure(figsize=(15, 15))\n",
    "start_range = digit_size // 2\n",
    "end_range = n * digit_size + start_range\n",
    "pixel_range = np.arange(start_range, end_range, digit_size)\n",
    "sample_range_x = np.round(grid_x, 1)\n",
    "sample_range_y = np.round(grid_y, 1)\n",
    "plt.xticks(pixel_range, sample_range_x)\n",
    "plt.yticks(pixel_range, sample_range_y)\n",
    "plt.xlabel(\"z[0]\")\n",
    "plt.ylabel(\"z[1]\")\n",
    "# NOTE: axis(\"off\") hides the ticks and labels configured above; remove it\n",
    "# to display the latent coordinates around the grid.\n",
    "plt.axis(\"off\")\n",
    "plt.imshow(figure, cmap=\"Greys_r\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Diffusion models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### The Oxford Flowers dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "fpath = keras.utils.get_file(\n",
    "    origin=\"https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz\",\n",
    "    extract=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "image_size = 128\n",
    "images_dir = os.path.join(fpath, \"jpg\")\n",
    "dataset = keras.utils.image_dataset_from_directory(\n",
    "    images_dir,\n",
    "    labels=None,\n",
    "    image_size=(image_size, image_size),\n",
    "    crop_to_aspect_ratio=True,\n",
    ")\n",
    "dataset = dataset.rebatch(\n",
    "    batch_size,\n",
    "    drop_remainder=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Show the first image of the first batch yielded by the dataset.\n",
    "first_batch = next(iter(dataset))\n",
    "img = first_batch.numpy()[0]\n",
    "plt.imshow(img.astype(\"uint8\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### A U-Net denoising autoencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def residual_block(x, width):\n",
    "    \"\"\"BatchNorm -> 3x3 swish conv -> 3x3 conv, added to a shortcut of `x`.\n",
    "\n",
    "    When the size of axis 3 of `x` differs from `width`, the shortcut goes\n",
    "    through a 1x1 convolution so that both branches can be added.\n",
    "    \"\"\"\n",
    "    shortcut = x if x.shape[3] == width else layers.Conv2D(width, 1)(x)\n",
    "    out = layers.BatchNormalization(center=False, scale=False)(x)\n",
    "    out = layers.Conv2D(width, 3, padding=\"same\", activation=\"swish\")(out)\n",
    "    out = layers.Conv2D(width, 3, padding=\"same\")(out)\n",
    "    return out + shortcut\n",
    "\n",
    "def get_model(image_size, widths, block_depth):\n",
    "    \"\"\"Build the U-Net mapping (noisy images, noise rates) to noise masks.\n",
    "\n",
    "    `widths` gives the filter count used at each level (the last entry is\n",
    "    the innermost level) and `block_depth` the number of residual blocks\n",
    "    per level.\n",
    "    \"\"\"\n",
    "    noisy_images = keras.Input(shape=(image_size, image_size, 3))\n",
    "    noise_rates = keras.Input(shape=(1, 1, 1))\n",
    "\n",
    "    x = layers.Conv2D(widths[0], 1)(noisy_images)\n",
    "    # Upsample the (1, 1, 1) noise rate by `image_size` and append it to the\n",
    "    # image features along the channel axis.\n",
    "    noise_map = layers.UpSampling2D(image_size, interpolation=\"nearest\")(\n",
    "        noise_rates\n",
    "    )\n",
    "    x = layers.Concatenate()([x, noise_map])\n",
    "\n",
    "    outer_widths, inner_width = widths[:-1], widths[-1]\n",
    "\n",
    "    # Downsampling path: keep every block output for the skip connections.\n",
    "    skips = []\n",
    "    for width in outer_widths:\n",
    "        for _ in range(block_depth):\n",
    "            x = residual_block(x, width)\n",
    "            skips.append(x)\n",
    "        x = layers.AveragePooling2D(pool_size=2)(x)\n",
    "\n",
    "    # Innermost level.\n",
    "    for _ in range(block_depth):\n",
    "        x = residual_block(x, inner_width)\n",
    "\n",
    "    # Upsampling path: skips are consumed last-in, first-out.\n",
    "    for width in reversed(outer_widths):\n",
    "        x = layers.UpSampling2D(size=2, interpolation=\"bilinear\")(x)\n",
    "        for _ in range(block_depth):\n",
    "            x = layers.Concatenate()([x, skips.pop()])\n",
    "            x = residual_block(x, width)\n",
    "\n",
    "    # Final 1x1 convolution to 3 channels; its kernel starts at zero.\n",
    "    pred_noise_masks = layers.Conv2D(3, 1, kernel_initializer=\"zeros\")(x)\n",
    "\n",
    "    return keras.Model([noisy_images, noise_rates], pred_noise_masks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### The concepts of diffusion time and diffusion schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def diffusion_schedule(\n",
    "    diffusion_times,\n",
    "    min_signal_rate=0.02,\n",
    "    max_signal_rate=0.95,\n",
    "):\n",
    "    \"\"\"Map diffusion times to `(noise_rates, signal_rates)`.\n",
    "\n",
    "    The times are mapped linearly onto angles running from\n",
    "    arccos(max_signal_rate) at time 0 to arccos(min_signal_rate) at time 1.\n",
    "    The signal rate is the cosine of that angle and the noise rate its sine.\n",
    "    \"\"\"\n",
    "    first_angle = ops.cast(ops.arccos(max_signal_rate), \"float32\")\n",
    "    last_angle = ops.cast(ops.arccos(min_signal_rate), \"float32\")\n",
    "    angles = first_angle + diffusion_times * (last_angle - first_angle)\n",
    "    return ops.sin(angles), ops.cos(angles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "diffusion_times = ops.arange(0.0, 1.0, 0.01)\n",
    "noise_rates, signal_rates = diffusion_schedule(diffusion_times)\n",
    "\n",
    "diffusion_times = ops.convert_to_numpy(diffusion_times)\n",
    "noise_rates = ops.convert_to_numpy(noise_rates)\n",
    "signal_rates = ops.convert_to_numpy(signal_rates)\n",
    "\n",
    "plt.plot(diffusion_times, noise_rates, label=\"Noise rate\")\n",
    "plt.plot(diffusion_times, signal_rates, label=\"Signal rate\")\n",
    "\n",
    "plt.xlabel(\"Diffusion time\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### The training process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class DiffusionModel(keras.Model):\n",
    "    \"\"\"Diffusion model built around the denoising U-Net from `get_model`.\n",
    "\n",
    "    `call` noises a batch of images at random diffusion times and asks the\n",
    "    U-Net to predict the noise; `generate` runs the reverse process starting\n",
    "    from pure noise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_size, widths, block_depth, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.image_size = image_size\n",
    "        self.denoising_model = get_model(image_size, widths, block_depth)\n",
    "        self.seed_generator = keras.random.SeedGenerator()\n",
    "        self.loss = keras.losses.MeanAbsoluteError()\n",
    "        # Its mean and variance are read back in generate().\n",
    "        self.normalizer = keras.layers.Normalization()\n",
    "\n",
    "    def denoise(self, noisy_images, noise_rates, signal_rates):\n",
    "        \"\"\"Predict the noise masks and derive the image estimates from them.\"\"\"\n",
    "        pred_noise_masks = self.denoising_model([noisy_images, noise_rates])\n",
    "        # Invert noisy = signal_rate * image + noise_rate * noise.\n",
    "        pred_images = (\n",
    "            noisy_images - noise_rates * pred_noise_masks\n",
    "        ) / signal_rates\n",
    "        return pred_images, pred_noise_masks\n",
    "\n",
    "    def call(self, images):\n",
    "        \"\"\"Return `(pred_images, pred_noise_masks, noise_masks)` for a batch.\"\"\"\n",
    "        images = self.normalizer(images)\n",
    "        # Size the random draws from the incoming batch instead of the\n",
    "        # notebook-level `batch_size` global, so the model does not depend on\n",
    "        # hidden state and accepts batches of any size.\n",
    "        num_images = ops.shape(images)[0]\n",
    "        noise_masks = keras.random.normal(\n",
    "            (num_images, self.image_size, self.image_size, 3),\n",
    "            seed=self.seed_generator,\n",
    "        )\n",
    "        diffusion_times = keras.random.uniform(\n",
    "            (num_images, 1, 1, 1),\n",
    "            minval=0.0,\n",
    "            maxval=1.0,\n",
    "            seed=self.seed_generator,\n",
    "        )\n",
    "        noise_rates, signal_rates = diffusion_schedule(diffusion_times)\n",
    "        noisy_images = signal_rates * images + noise_rates * noise_masks\n",
    "        pred_images, pred_noise_masks = self.denoise(\n",
    "            noisy_images, noise_rates, signal_rates\n",
    "        )\n",
    "        return pred_images, pred_noise_masks, noise_masks\n",
    "\n",
    "    def compute_loss(self, x, y, y_pred, sample_weight=None, training=True):\n",
    "        \"\"\"Loss between the sampled noise masks and the predicted ones.\"\"\"\n",
    "        _, pred_noise_masks, noise_masks = y_pred\n",
    "        return self.loss(noise_masks, pred_noise_masks)\n",
    "\n",
    "    def generate(self, num_images, diffusion_steps):\n",
    "        \"\"\"Generate `num_images` images in `diffusion_steps` denoising steps.\n",
    "\n",
    "        Returns pixel values clipped to [0, 255]. Raises ValueError when\n",
    "        `diffusion_steps` is smaller than 1.\n",
    "        \"\"\"\n",
    "        if diffusion_steps < 1:\n",
    "            # Without at least one step there is no prediction to return.\n",
    "            raise ValueError(\"diffusion_steps must be at least 1\")\n",
    "        noisy_images = keras.random.normal(\n",
    "            (num_images, self.image_size, self.image_size, 3),\n",
    "            seed=self.seed_generator,\n",
    "        )\n",
    "        step_size = 1.0 / diffusion_steps\n",
    "        for step in range(diffusion_steps):\n",
    "            # Diffusion time goes from 1 down towards 0.\n",
    "            diffusion_times = ops.ones((num_images, 1, 1, 1)) - step * step_size\n",
    "            noise_rates, signal_rates = diffusion_schedule(diffusion_times)\n",
    "            pred_images, pred_noises = self.denoise(\n",
    "                noisy_images, noise_rates, signal_rates\n",
    "            )\n",
    "            # Re-mix the predictions at the rates of the next, lower time.\n",
    "            next_diffusion_times = diffusion_times - step_size\n",
    "            next_noise_rates, next_signal_rates = diffusion_schedule(\n",
    "                next_diffusion_times\n",
    "            )\n",
    "            noisy_images = (\n",
    "                next_signal_rates * pred_images + next_noise_rates * pred_noises\n",
    "            )\n",
    "        # Undo the normalization: mean + x * sqrt(variance).\n",
    "        images = (\n",
    "            self.normalizer.mean + pred_images * self.normalizer.variance**0.5\n",
    "        )\n",
    "        return ops.clip(images, 0.0, 255.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### The generation process"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Visualizing results with a custom callback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class VisualizationCallback(keras.callbacks.Callback):\n",
    "    \"\"\"Plots a `num_rows` x `num_cols` grid of generated images every epoch.\"\"\"\n",
    "\n",
    "    def __init__(self, diffusion_steps=20, num_rows=3, num_cols=6):\n",
    "        # Let the base Callback initialize its own attributes first.\n",
    "        super().__init__()\n",
    "        self.diffusion_steps = diffusion_steps\n",
    "        self.num_rows = num_rows\n",
    "        self.num_cols = num_cols\n",
    "\n",
    "    def on_epoch_end(self, epoch=None, logs=None):\n",
    "        \"\"\"Generate one image per grid slot with the model and display them.\"\"\"\n",
    "        generated_images = self.model.generate(\n",
    "            num_images=self.num_rows * self.num_cols,\n",
    "            diffusion_steps=self.diffusion_steps,\n",
    "        )\n",
    "\n",
    "        plt.figure(figsize=(self.num_cols * 2.0, self.num_rows * 2.0))\n",
    "        for row in range(self.num_rows):\n",
    "            for col in range(self.num_cols):\n",
    "                i = row * self.num_cols + col\n",
    "                plt.subplot(self.num_rows, self.num_cols, i + 1)\n",
    "                img = ops.convert_to_numpy(generated_images[i]).astype(\"uint8\")\n",
    "                plt.imshow(img)\n",
    "                plt.axis(\"off\")\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        # Close the figure so one is not left open after every epoch.\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### It's go time!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = DiffusionModel(image_size, widths=[32, 64, 96, 128], block_depth=2)\n",
    "model.normalizer.adapt(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=keras.optimizers.AdamW(\n",
    "        learning_rate=keras.optimizers.schedules.InverseTimeDecay(\n",
    "            initial_learning_rate=1e-3,\n",
    "            decay_steps=1000,\n",
    "            decay_rate=0.1,\n",
    "        ),\n",
    "        use_ema=True,\n",
    "        ema_overwrite_frequency=100,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    dataset,\n",
    "    epochs=100,\n",
    "    callbacks=[\n",
    "        VisualizationCallback(),\n",
    "        keras.callbacks.ModelCheckpoint(\n",
    "            filepath=\"diffusion_model.weights.h5\",\n",
    "            save_weights_only=True,\n",
    "            # fit() gets no validation data, so the default monitor\n",
    "            # (\"val_loss\") never appears in the logs and save_best_only\n",
    "            # would skip every save. Track the training loss instead.\n",
    "            monitor=\"loss\",\n",
    "            save_best_only=True,\n",
    "        ),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Text-to-image models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "if keras.config.backend() == \"torch\":\n",
    "    # The rest of this chapter will not do any training. The following keeps\n",
    "    # PyTorch from using too much memory by disabling gradients. TensorFlow and\n",
    "    # JAX use a much smaller memory footprint and do not need this hack.\n",
    "    import torch\n",
    "\n",
    "    torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras_hub\n",
    "\n",
    "height, width = 512, 512\n",
    "task = keras_hub.models.TextToImage.from_preset(\n",
    "    \"stable_diffusion_3_medium\",\n",
    "    image_shape=(height, width, 3),\n",
    "    dtype=\"float16\",\n",
    ")\n",
    "prompt = \"A NASA astraunaut riding an origami elephant in New York City\"\n",
    "task.generate(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "task.generate(\n",
    "    {\n",
    "        \"prompts\": prompt,\n",
    "        \"negative_prompts\": \"blue color\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "def display(images):\n",
    "    \"\"\"Concatenate `images` along axis 1 and return the result as a PIL image.\"\"\"\n",
    "    strip = np.concatenate(images, axis=1)\n",
    "    return Image.fromarray(strip)\n",
    "\n",
    "# Same prompt generated with an increasing number of denoising steps.\n",
    "display(\n",
    "    [\n",
    "        task.generate(prompt, num_steps=num_steps)\n",
    "        for num_steps in (5, 10, 15, 20, 25)\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Exploring the latent space of a text-to-image model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import random\n",
    "\n",
    "def get_text_embeddings(prompt):\n",
    "    \"\"\"Encode `prompt`, paired with an empty negative prompt, into embeddings.\"\"\"\n",
    "    preprocess = task.preprocessor.generate_preprocess\n",
    "    positive_ids = preprocess([prompt])\n",
    "    negative_ids = preprocess([\"\"])\n",
    "    return task.backbone.encode_text_step(positive_ids, negative_ids)\n",
    "\n",
    "def denoise_with_text_embeddings(embeddings, num_steps=28, guidance_scale=7.0):\n",
    "    \"\"\"Denoise random latents for `num_steps` steps and decode the result.\n",
    "\n",
    "    The latent shape is derived from the notebook-level `height` and `width`.\n",
    "    Returns the first (and only) item of the decoded batch.\n",
    "    \"\"\"\n",
    "    latent_shape = (1, height // 8, width // 8, 16)\n",
    "    latents = random.normal(latent_shape)\n",
    "    for step_index in range(num_steps):\n",
    "        latents = task.backbone.denoise_step(\n",
    "            latents,\n",
    "            embeddings,\n",
    "            step_index,\n",
    "            num_steps,\n",
    "            guidance_scale,\n",
    "        )\n",
    "    decoded = task.backbone.decode_step(latents)\n",
    "    return decoded[0]\n",
    "\n",
    "def scale_output(x):\n",
    "    \"\"\"Map values from [-1, 1] (clipped) to uint8 values in [0, 255].\"\"\"\n",
    "    values = ops.convert_to_numpy(x)\n",
    "    unit_range = np.clip((values + 1.0) / 2.0, 0.0, 1.0)\n",
    "    return np.round(unit_range * 255.0).astype(\"uint8\")\n",
    "\n",
    "embeddings = get_text_embeddings(prompt)\n",
    "image = denoise_with_text_embeddings(embeddings)\n",
    "scale_output(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "[x.shape for x in embeddings]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import ops\n",
    "\n",
    "def slerp(t, v1, v2):\n",
    "    \"\"\"Spherical linear interpolation between `v1` and `v2` at fraction `t`.\n",
    "\n",
    "    All inputs are converted to float32. The angle between the flattened\n",
    "    vectors is taken from their normalized dot product. When the vectors are\n",
    "    (anti)parallel the spherical weights are 0 / 0, so the linear weights\n",
    "    `1 - t` and `t` are used instead of returning NaN.\n",
    "    \"\"\"\n",
    "    t = ops.convert_to_tensor(t, dtype=\"float32\")\n",
    "    v1, v2 = ops.cast(v1, \"float32\"), ops.cast(v2, \"float32\")\n",
    "    v1_norm = ops.linalg.norm(ops.ravel(v1))\n",
    "    v2_norm = ops.linalg.norm(ops.ravel(v2))\n",
    "    dot = ops.sum(v1 * v2 / (v1_norm * v2_norm))\n",
    "    # Rounding can push the cosine slightly outside [-1, 1], where arccos\n",
    "    # is NaN.\n",
    "    dot = ops.clip(dot, -1.0, 1.0)\n",
    "    theta_0 = ops.arccos(dot)\n",
    "    sin_theta_0 = ops.sin(theta_0)\n",
    "    theta_t = theta_0 * t\n",
    "    # Guard the divisions below against a vanishing sin(theta_0).\n",
    "    degenerate = ops.less(ops.abs(sin_theta_0), 1e-6)\n",
    "    safe_sin = ops.where(degenerate, ops.ones_like(sin_theta_0), sin_theta_0)\n",
    "    s0 = ops.where(degenerate, 1.0 - t, ops.sin(theta_0 - theta_t) / safe_sin)\n",
    "    s1 = ops.where(degenerate, t, ops.sin(theta_t) / safe_sin)\n",
    "    return s0 * v1 + s1 * v2\n",
    "\n",
    "def interpolate_text_embeddings(e1, e2, start=0, stop=1, num=10):\n",
    "    \"\"\"Return `num` embedding tuples blended between `e1` and `e2`.\n",
    "\n",
    "    Items 0 and 2 of the tuples are interpolated with `slerp` at fractions\n",
    "    evenly spaced over [start, stop]; items 1 and 3 are always taken\n",
    "    unchanged from `e1`.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        (\n",
    "            slerp(t, e1[0], e2[0]),\n",
    "            e1[1],\n",
    "            slerp(t, e1[2], e2[2]),\n",
    "            e1[3],\n",
    "        )\n",
    "        for t in np.linspace(start, stop, num)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "prompt1 = \"A friendly dog looking up in a field of flowers\"\n",
    "prompt2 = \"A horrifying, tentacled creature hovering over a field of flowers\"\n",
    "e1 = get_text_embeddings(prompt1)\n",
    "e2 = get_text_embeddings(prompt2)\n",
    "\n",
    "# Nine blends taken from the interval t in [0.5, 0.6] between the prompts.\n",
    "blended = interpolate_text_embeddings(e1, e2, start=0.5, stop=0.6, num=9)\n",
    "images = [scale_output(denoise_with_text_embeddings(et)) for et in blended]\n",
    "display(images)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "chapter17_image-generation",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}